{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "22771bcc-be48-4bc6-906e-e450568a8734",
   "metadata": {},
   "source": [
    "# Few Shot Learning with AutoMM\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/autogluon/autogluon/blob/master/docs/tutorials/multimodal/advanced_topics/few_shot_learning.ipynb)\n",
    "[![Open In SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/autogluon/autogluon/blob/master/docs/tutorials/multimodal/advanced_topics/few_shot_learning.ipynb)\n",
    "\n",
    "\n",
    "In this tutorial we introduce a simple but effective way for few shot classification problems. \n",
    "We present the functionality which leverages the high-quality features from foundation models and uses SVM for few shot classification tasks.\n",
    "Specifically, we extract sample features with pretrained models, and use the features for SVM learning.\n",
    "We show the effectiveness of the foundation-model-followed-by-SVM on a text classification dataset and an image classification dataset."
   ]
  },
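  {
   "cell_type": "markdown",
   "id": "fs-svm-sketch-0001",
   "metadata": {},
   "source": [
    "To make the recipe concrete, here is a minimal, self-contained sketch of fitting an SVM on frozen features.\n",
    "It is not AutoMM's implementation: `fake_embed` is a hypothetical stand-in that returns synthetic vectors instead of real foundation-model features, and the linear kernel is an assumption chosen for simplicity.\n",
    "The sketch assumes `numpy` and `scikit-learn` are installed.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "num_classes, shots, dim = 4, 10, 16\n",
    "\n",
    "# Hypothetical stand-in for a frozen foundation model: each class gets a\n",
    "# well-separated center, and a sample's feature vector is that center plus noise.\n",
    "centers = rng.normal(size=(num_classes, dim)) * 5.0\n",
    "\n",
    "def fake_embed(labels):\n",
    "    return centers[labels] + rng.normal(size=(len(labels), dim))\n",
    "\n",
    "# Few-shot training set: 10 samples (shots) per class, plus a small test set.\n",
    "train_labels = np.repeat(np.arange(num_classes), shots)\n",
    "test_labels = np.repeat(np.arange(num_classes), 10)\n",
    "\n",
    "# Only the SVM is trained; the feature extractor stays fixed.\n",
    "svm = SVC(kernel='linear')\n",
    "svm.fit(fake_embed(train_labels), train_labels)\n",
    "\n",
    "preds = svm.predict(fake_embed(test_labels))\n",
    "accuracy = float((preds == test_labels).mean())\n",
    "print(f'accuracy on synthetic features: {accuracy:.2f}')\n",
    "```\n",
    "\n",
    "In this sketch only the SVM is fit while the features stay fixed, which is why a handful of labeled samples per class can be enough when the features separate the classes well."
   ]
  },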
  {
   "cell_type": "markdown",
   "source": [
    "## Few Shot Text Classification\n",
    "### Prepare Text Data\n",
    "We prepare all datasets in the format of `pd.DataFrame` as in many of our tutorials have done.\n",
    "For this tutorial, we'll use a small `MLDoc` dataset for demonstration.\n",
    "The dataset is a text classification dataset, which contains 4 classes and we downsampled the training data to 10 samples per class, a.k.a 10 shots.\n",
    "For more details regarding `MLDoc` please see this [link](https://github.com/facebookresearch/MLDoc)."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "c3f7d4164a414ef5"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4995899a-a489-4861-95f1-8312ed83f867",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "from autogluon.core.utils.loaders import load_zip\n",
    "\n",
    "download_dir = \"./ag_automm_tutorial_fs_cls\"\n",
    "zip_file = \"https://automl-mm-bench.s3.amazonaws.com/nlp_datasets/MLDoc-10shot-en.zip\"\n",
    "load_zip.unzip(zip_file, unzip_dir=download_dir)\n",
    "dataset_path = os.path.join(download_dir)\n",
    "train_df = pd.read_csv(f\"{dataset_path}/train.csv\", names=[\"label\", \"text\"])\n",
    "test_df = pd.read_csv(f\"{dataset_path}/test.csv\", names=[\"label\", \"text\"])\n",
    "print(train_df)\n",
    "print(test_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99c4fe3a-3acd-4db1-af27-e8aca839bb66",
   "metadata": {},
   "source": [
    "### Train a Few Shot Classifier\n",
    "In order to perform few shot classification, we need to use the `few_shot_classification` problem type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "393bd8f4-76bc-4889-ade2-e94c9d7374bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from autogluon.multimodal import MultiModalPredictor\n",
    "\n",
    "predictor_fs_text = MultiModalPredictor(\n",
    "    problem_type=\"few_shot_classification\",\n",
    "    label=\"label\",  # column name of the label\n",
    "    eval_metric=\"acc\",\n",
    ")\n",
    "predictor_fs_text.fit(train_df)\n",
    "scores = predictor_fs_text.evaluate(test_df, metrics=[\"acc\", \"f1_macro\"])\n",
    "print(scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d12e47a0-21a9-4261-9bd2-5c172ba8fc5d",
   "metadata": {},
   "source": [
    "### Compare to the Default Classifier\n",
    "Let's use the default `classification` problem type and compare the performance with the above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c2aab3c-fcd7-4220-b1e8-d2799ecb9566",
   "metadata": {},
   "outputs": [],
   "source": [
    "from autogluon.multimodal import MultiModalPredictor\n",
    "\n",
    "predictor_default_text = MultiModalPredictor(\n",
    "    label=\"label\",\n",
    "    problem_type=\"classification\",\n",
    "    eval_metric=\"acc\",\n",
    ")\n",
    "predictor_default_text.fit(train_data=train_df)\n",
    "scores = predictor_default_text.evaluate(test_df, metrics=[\"acc\", \"f1_macro\"])\n",
    "print(scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e28dd113-01c9-48ee-8fd9-7475f6b5020e",
   "metadata": {},
   "source": [
    "## Few Shot Image Classification\n",
    "We also provide an example of using `MultiModalPredictor` on a few-shot image classification task.\n",
    "### Load Dataset\n",
    "We use the Stanford Cars dataset for demonstration and have downsampled the training set to have 8 samples per class.\n",
    "The Stanford Cars is an image classification dataset and contains 196 classes.\n",
    "For more information regarding the dataset, please see [here](https://www.kaggle.com/datasets/jessicali9530/stanford-cars-dataset)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80bfb087-ac7a-4c61-83cc-45e0e030bda3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from autogluon.core.utils.loaders import load_zip, load_s3\n",
    "\n",
    "download_dir = \"./ag_automm_tutorial_fs_cls/stanfordcars/\"\n",
    "zip_file = \"https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/stanfordcars.zip\"\n",
    "train_csv = \"https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/train_8shot.csv\"\n",
    "test_csv = \"https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/test.csv\"\n",
    "\n",
    "load_zip.unzip(zip_file, unzip_dir=download_dir)\n",
    "dataset_path = os.path.join(download_dir)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80bfb087-ac7a-4c61-83cc-45e0e030bda4",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/train_8shot.csv -O ./ag_automm_tutorial_fs_cls/stanfordcars/train.csv\n",
    "!wget https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/test.csv -O ./ag_automm_tutorial_fs_cls/stanfordcars/test.csv\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80bfb087-ac7a-4c61-83cc-45e0e030bda5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "train_df_raw = pd.read_csv(os.path.join(download_dir, \"train.csv\"))\n",
    "train_df = train_df_raw.drop(\n",
    "        columns=[\n",
    "            \"Source\",\n",
    "            \"Confidence\",\n",
    "            \"XMin\",\n",
    "            \"XMax\",\n",
    "            \"YMin\",\n",
    "            \"YMax\",\n",
    "            \"IsOccluded\",\n",
    "            \"IsTruncated\",\n",
    "            \"IsGroupOf\",\n",
    "            \"IsDepiction\",\n",
    "            \"IsInside\",\n",
    "        ]\n",
    "    )\n",
    "train_df[\"ImageID\"] = download_dir + train_df[\"ImageID\"].astype(str)\n",
    "\n",
    "\n",
    "test_df_raw = pd.read_csv(os.path.join(download_dir, \"test.csv\"))\n",
    "test_df = test_df_raw.drop(\n",
    "        columns=[\n",
    "            \"Source\",\n",
    "            \"Confidence\",\n",
    "            \"XMin\",\n",
    "            \"XMax\",\n",
    "            \"YMin\",\n",
    "            \"YMax\",\n",
    "            \"IsOccluded\",\n",
    "            \"IsTruncated\",\n",
    "            \"IsGroupOf\",\n",
    "            \"IsDepiction\",\n",
    "            \"IsInside\",\n",
    "        ]\n",
    "    )\n",
    "test_df[\"ImageID\"] = download_dir + test_df[\"ImageID\"].astype(str)\n",
    "\n",
    "print(os.path.exists(train_df.iloc[0][\"ImageID\"]))\n",
    "print(train_df)\n",
    "print(os.path.exists(test_df.iloc[0][\"ImageID\"]))\n",
    "print(test_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8476756e-38eb-49f9-941b-11a4027fb840",
   "metadata": {},
   "source": [
    "### Train a Few Shot Classifier\n",
    "Similarly, we need to initialize `MultiModalPredictor` with the problem type `few_shot_classification`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "856f92fd-eb23-417e-9968-21db3af9a3b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from autogluon.multimodal import MultiModalPredictor\n",
    "\n",
    "predictor_fs_image = MultiModalPredictor(\n",
    "    problem_type=\"few_shot_classification\",\n",
    "    label=\"LabelName\",  # column name of the label\n",
    "    eval_metric=\"acc\",\n",
    ")\n",
    "predictor_fs_image.fit(train_df)\n",
    "scores = predictor_fs_image.evaluate(test_df, metrics=[\"acc\", \"f1_macro\"])\n",
    "print(scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1953ec4e-7371-42af-ba29-6102d0d212c1",
   "metadata": {},
   "source": [
    "### Compare to the Default Classifier\n",
    "We can also train a default image classifier and compare to the few shot classifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b390066-99fc-4cd8-b542-5e8f01ded73d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from autogluon.multimodal import MultiModalPredictor\n",
    "\n",
    "predictor_default_image = MultiModalPredictor(\n",
    "    problem_type=\"classification\",\n",
    "    label=\"LabelName\",  # column name of the label\n",
    "    eval_metric=\"acc\",\n",
    ")\n",
    "predictor_default_image.fit(train_data=train_df)\n",
    "scores = predictor_default_image.evaluate(test_df, metrics=[\"acc\", \"f1_macro\"])\n",
    "print(scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bcf11ec9-2e9a-42a1-9dc0-69137dad2d5a",
   "metadata": {},
   "source": [
    "As you can see that the `few_shot_classification` performs much better than the default `classification` in image classification as well."
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Customization\n",
    "To learn how to customize AutoMM, please refer to [Customize AutoMM](../advanced_topics/customization.ipynb)."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "48e372691fe7f9c0"
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
